"""Simplify pipelines [b59aa68fdb1f].

Revision ID: b59aa68fdb1f
Revises: 0.62.0
Create Date: 2024-07-04 14:00:32.830722

"""

from typing import Dict, Optional

import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import mysql

# Revision identifiers, used by Alembic.
# `revision` is this migration's ID and `down_revision` is the revision it
# builds on; both match the "Revision ID" / "Revises" lines in the module
# docstring above.
revision = "b59aa68fdb1f"
down_revision = "0.62.0"
# This migration declares no branch labels and no extra dependencies.
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    The migration runs in three steps:

    1. Add nullable `pipeline_version_hash` and `pipeline_spec` columns to
       the `pipeline_deployment` table.
    2. Copy each pipeline's `version_hash` and `spec` onto the deployments
       that reference it. Pipelines sharing a name are then collapsed into
       the first row returned for that name: rows in `pipeline_run`,
       `pipeline_deployment`, `pipeline_build` and `schedule` that point at
       a later duplicate are re-pointed and the duplicate row is deleted.
    3. Add a `description` column to `pipeline` and drop its `spec`,
       `docstring`, `version_hash` and `version` columns.
    """
    # Step 1: add the deployment columns that will receive the per-pipeline
    # data. They are nullable because existing rows have no value yet.
    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        for new_column in (
            sa.Column(
                "pipeline_version_hash",
                sqlmodel.sql.sqltypes.AutoString(),
                nullable=True,
            ),
            sa.Column(
                "pipeline_spec",
                sa.String(length=16777215).with_variant(
                    mysql.MEDIUMTEXT, "mysql"
                ),
                nullable=True,
            ),
        ):
            batch_op.add_column(new_column)

    # Step 2: data migration. Reflection happens after the column additions
    # so the reflected `pipeline_deployment` table includes the new columns.
    bind = op.get_bind()
    metadata = sa.MetaData()
    metadata.reflect(
        bind=bind,
        only=(
            "pipeline",
            "pipeline_run",
            "pipeline_deployment",
            "pipeline_build",
            "schedule",
        ),
    )
    pipelines = metadata.tables["pipeline"]
    deployments = metadata.tables["pipeline_deployment"]
    # Tables whose `pipeline_id` has to follow a merged pipeline, in the
    # order in which they are updated.
    referencing_tables = (
        metadata.tables["pipeline_run"],
        deployments,
        metadata.tables["pipeline_build"],
        metadata.tables["schedule"],
    )

    # Maps a pipeline name to the ID of the row kept for that name.
    kept_id_by_name: Dict[str, str] = {}

    # Materialize all rows up front: the loop deletes from this table.
    for row in bind.execute(sa.select(pipelines)).fetchall():
        # Backfill must run before the foreign keys are re-pointed, while the
        # deployments still reference this row's own ID.
        backfill = (
            sa.update(deployments)
            .where(deployments.c.pipeline_id == row.id)
            .values(
                pipeline_version_hash=row.version_hash,
                pipeline_spec=row.spec,
            )
        )
        bind.execute(backfill)

        kept_id = kept_id_by_name.get(row.name)
        if not kept_id:
            # First row for this name: it becomes the one that is kept.
            kept_id_by_name[row.name] = row.id
            continue

        # Duplicate name: move every reference to the kept row, then remove
        # the duplicate.
        for referencing in referencing_tables:
            repoint = (
                sa.update(referencing)
                .where(referencing.c.pipeline_id == row.id)
                .values(pipeline_id=kept_id)
            )
            bind.execute(repoint)
        bind.execute(sa.delete(pipelines).where(pipelines.c.id == row.id))

    # Step 3: reshape the `pipeline` table now that its per-version data
    # lives on the deployments.
    with op.batch_alter_table("pipeline", schema=None) as batch_op:
        batch_op.add_column(sa.Column("description", sa.TEXT(), nullable=True))
        for obsolete in ("spec", "docstring", "version_hash", "version"):
            batch_op.drop_column(obsolete)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Only the schema is reverted: the columns added to `pipeline_deployment`
    are dropped and the columns removed from `pipeline` are re-created. No
    data is copied back into the re-created columns, and pipeline rows that
    were merged and deleted by `upgrade` are not restored.
    """
    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        batch_op.drop_column("pipeline_spec")
        batch_op.drop_column("pipeline_version_hash")

    # NOTE(review): `version`, `version_hash` and `spec` are NOT NULL and
    # have no server default, so adding them to a `pipeline` table that
    # already contains rows may be rejected by the backend -- confirm
    # whether downgrading a populated database needs to be supported.
    with op.batch_alter_table("pipeline", schema=None) as batch_op:
        # Use the same portable string types as `upgrade`: a bare
        # `VARCHAR()` has no length, which the MySQL dialect refuses to
        # compile.
        batch_op.add_column(
            sa.Column(
                "version",
                sqlmodel.sql.sqltypes.AutoString(),
                nullable=False,
            )
        )
        batch_op.add_column(
            sa.Column(
                "version_hash",
                sqlmodel.sql.sqltypes.AutoString(),
                nullable=False,
            )
        )
        batch_op.add_column(sa.Column("docstring", sa.TEXT(), nullable=True))
        # A 16777215-character VARCHAR exceeds MySQL's VARCHAR limit, so
        # fall back to MEDIUMTEXT there, as `upgrade` does for
        # `pipeline_spec`.
        batch_op.add_column(
            sa.Column(
                "spec",
                sa.String(length=16777215).with_variant(
                    mysql.MEDIUMTEXT, "mysql"
                ),
                nullable=False,
            )
        )
        batch_op.drop_column("description")
